import logging
import os
from itertools import cycle
 
import numpy as np
import gin
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchattacks
from timm.utils import ModelEmaV2
from torchmetrics import Accuracy, MeanMetric
from tqdm import tqdm
 
from util.awp import AWP
from util.bypass_bn import disable_running_stats, enable_running_stats, set_bn_momentum
from util.pgd_attack import PGD
from util.trades_attack import TRADES
from util.utils import save_ckpt
import pandas as pd

import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import random
# Label-distribution tables for abnormal batches: a batch id, one column per
# class (100 classes are hard-coded here) and the weight-gradient norm.
columns = ['Batch ID', *(f'Class_{i}' for i in range(100)), 'WGrad Norm']
batch_records = pd.DataFrame(columns=columns)
entropy_records = pd.DataFrame(columns=columns)

# Per-batch counts of the four classification outcomes.
# NOTE(review): the pairs presumably mean clean/adversarial correctness -
# confirm against the code that fills these tables.
batch_classification_counts = pd.DataFrame(
    columns=[
        'Batch ID',
        'Correct-Correct',
        'Correct-Incorrect',
        'Incorrect-Correct',
        'Incorrect-Incorrect',
    ]
)

# Same outcome layout, kept separately for abnormal batches.
abnormal_classification_counts = pd.DataFrame(
    columns=list(batch_classification_counts.columns)
)
# One per-class table for each key 1..4 (batch id + class columns, no norm).
class_based_counts = {
    key: pd.DataFrame(columns=columns[:-1]) for key in range(1, 5)
}
# Number of upcoming training batches that get Gaussian input noise; set by
# AdvTrainer.train and counted down in AdvTrainer._train_one_epoch.
add_noise_batches = 0

def save_random_state():
    """Snapshot the RNG state of torch, NumPy and Python's ``random``.

    Returns:
        A 3-tuple ``(torch_state, numpy_state, random_state)`` in the order
        expected by ``restore_random_state``.
    """
    # NOTE: torch.get_rng_state() covers the default (CPU) generator; any
    # CUDA generator state is not part of this snapshot.
    torch_rng = torch.get_rng_state()
    numpy_rng = np.random.get_state()
    python_rng = random.getstate()
    return torch_rng, numpy_rng, python_rng

def restore_random_state(torch_state, numpy_state, random_state):
    """Reinstall RNG states previously captured by ``save_random_state``.

    Args:
        torch_state: State accepted by ``torch.set_rng_state``.
        numpy_state: State accepted by ``np.random.set_state``.
        random_state: State accepted by ``random.setstate``.
    """
    # Applied in the same order as the snapshot: torch, NumPy, then random.
    setters_and_states = (
        (torch.set_rng_state, torch_state),
        (np.random.set_state, numpy_state),
        (random.setstate, random_state),
    )
    for apply_state, saved in setters_and_states:
        apply_state(saved)

def get_random_orthogonal_directions(model):
    """Draw two random directions in parameter space, orthogonal per tensor.

    For every trainable parameter of ``model`` two Gaussian tensors of the
    same shape are sampled, and the second one has its projection onto the
    first removed, so each pair is orthogonal tensor by tensor (not only as
    whole concatenated vectors).

    The RNG state captured by ``save_random_state`` is restored afterwards,
    even if sampling fails, so calling this function does not shift the
    random stream seen by the rest of the program.

    Args:
        model: Module whose parameters with ``requires_grad=True`` define the
            shapes of the directions.

    Returns:
        ``(dir1, dir2)``: two lists of tensors, one entry per trainable
        parameter, in ``model.parameters()`` order.
    """
    # save_random_state() returns three states (torch, NumPy, random); keep
    # the tuple whole so it always matches restore_random_state's signature.
    rng_snapshot = save_random_state()
    try:
        params = [p.data for p in model.parameters() if p.requires_grad]
        dir1 = [torch.randn_like(p) for p in params]
        dir2 = [torch.randn_like(p) for p in params]

        # Gram-Schmidt step: remove from dir2 its component along dir1.
        for d1, d2 in zip(dir1, dir2):
            flat1 = d1.reshape(-1)
            proj = torch.dot(flat1, d2.reshape(-1)) / torch.dot(flat1, flat1)
            d2.add_(-proj * d1)
    finally:
        restore_random_state(*rng_snapshot)
    return dir1, dir2

def get_weights(net):
    """Return the ``.data`` tensor of every parameter of ``net`` as a list.

    The tensors are not copied: they share storage with the live parameters,
    in ``net.parameters()`` order.
    """
    weights = []
    for param in net.parameters():
        weights.append(param.data)
    return weights

def normalize_direction(d, w):
    """Rescale ``d`` in place so that its norm matches the norm of ``w``.

    Args:
        d: Direction tensor; modified in place.
        w: Reference tensor whose norm ``d`` should take.
    """
    # The small constant keeps the division finite for an all-zero direction.
    scale = w.norm() / (d.norm() + 1e-10)
    d.mul_(scale)

def normalize_directions_for_weights(direction, weights, ignore='biasbn'):
    """Normalize a list of direction tensors against matching weights, in place.

    Tensors with more than one dimension are rescaled with
    ``normalize_direction``. Tensors with at most one dimension are zeroed
    when ``ignore == 'biasbn'`` and overwritten with the corresponding weight
    otherwise.

    Args:
        direction: List of direction tensors; each one is modified in place.
        weights: List of reference tensors, same length as ``direction``.
        ignore: ``'biasbn'`` zeroes the <=1-D entries; any other value copies
            the weight into them.
    """
    assert len(direction) == len(weights)
    zero_low_dim = ignore == 'biasbn'
    for vec, ref in zip(direction, weights):
        if vec.dim() > 1:
            normalize_direction(vec, ref)
            continue
        if zero_low_dim:
            vec.fill_(0)
        else:
            vec.copy_(ref)

def perturb_and_evaluate(model, grad_direction, alpha_range, cls_loss_fn, kl_loss_fn, img, img_adv, label, adv_beta, device):
    """Evaluate the TRADES-style loss along one direction in weight space.

    The direction is normalized against the model weights with
    ``normalize_directions_for_weights``. For every ``alpha`` a copy of the
    model has ``alpha * direction`` added to its parameters, and the loss
    ``cls_loss_fn(clean) + adv_beta * kl_loss_fn(log_softmax(adv),
    log_softmax(clean))`` is recorded.

    Args:
        model: Model to probe; it is never modified (a deep copy is perturbed).
        grad_direction: List of tensors, one per model parameter. It is NOT
            modified: normalization is done on a private copy.
        alpha_range: Iterable of step sizes along the direction.
        cls_loss_fn: Classification loss called as ``fn(logits, label)``.
        kl_loss_fn: Divergence called with two log-probability tensors.
        img: Clean input batch.
        img_adv: Adversarial input batch.
        label: Targets for ``cls_loss_fn``.
        adv_beta: Weight of the robust (KL) term.
        device: Device the direction tensors are moved to before being added.

    Returns:
        1-D tensor with one loss value per entry of ``alpha_range``.
    """
    # Normalization rescales / zeroes tensors in place. Work on clones so a
    # caller that still needs the raw gradients does not see them changed.
    direction = [g.clone() for g in grad_direction]
    normalize_directions_for_weights(direction, get_weights(model))
    # Loop-invariant: move the direction to the target device only once.
    direction = [g.to(device) for g in direction]

    loss_surface = []
    with torch.no_grad():
        for alpha in alpha_range:
            # A fresh copy per alpha keeps the original model (including any
            # buffers updated by the forward pass) untouched.
            perturbed = copy.deepcopy(model)
            for param, step in zip(perturbed.parameters(), direction):
                param.data += alpha * step

            logits_natural = perturbed(img)
            logits_adv = perturbed(img_adv)

            loss_natural = cls_loss_fn(logits_natural, label)
            loss_robust = kl_loss_fn(
                F.log_softmax(logits_adv, dim=1),
                F.log_softmax(logits_natural, dim=1),
            )
            full_loss = loss_natural + adv_beta * loss_robust
            loss_surface.append(full_loss.item())

            # Drop the copy before the next one is created to cap peak memory.
            del perturbed

    return torch.tensor(loss_surface)



@gin.configurable
class AdvTrainer:
    def __init__(
        self,
        device,
        hparam,
        adv_train_mode,
        model_ema_decay,
        train_attacker=None,
        eval_attacker=None,
        adv_beta=6.0,
        use_ema=True,
        aux_loader=None,
    ):
        """Set up metrics, attackers and options for adversarial training.

        Args:
            device: Device the metrics are moved to and batches are sent to.
            hparam: Mapping of hyper-parameters. ``"n_class"`` is required;
                ``"solution"`` is optional and read with ``.get``.
            adv_train_mode: One of ``"Normal"``, ``"PGD-AT"``, ``"TRADES"``.
            model_ema_decay: Decay passed to ``ModelEmaV2`` in ``train``.
            train_attacker: Object whose ``attack(model, img, label)`` builds
                the adversarial training batch (unused in ``"Normal"`` mode).
            eval_attacker: Object whose ``attack(model, img, label)`` is used
                during evaluation.
            adv_beta: Weight of the KL term; only used in ``"TRADES"`` mode.
            use_ema: Evaluate the EMA weights instead of the raw model.
            aux_loader: Optional loader of extra training batches; it is
                consumed endlessly with ``next(self.aux_loader)``.
        """
        assert adv_train_mode in ["Normal", "PGD-AT", "TRADES"]
        self.device = device
        self.hparam = hparam
        self.num_classes = self.hparam["n_class"]
        self.solution = hparam.get("solution")
        self.epoch_loss = MeanMetric().to(self.device)
        self.clean_acc = Accuracy(task="multiclass", num_classes=self.num_classes).to(
            self.device
        )
        self.adv_acc = Accuracy(task="multiclass", num_classes=self.num_classes).to(
            self.device
        )

        self.train_attacker = train_attacker
        self.eval_attacker = eval_attacker
        self.adv_beta = adv_beta  # Only for TRADES
        self.adv_train_mode = adv_train_mode

        self.use_ema = use_ema

        self.model_ema_decay = model_ema_decay

        def _repeat_forever(loader):
            # Re-iterate the loader on every pass instead of replaying a
            # cache, so memory stays flat and the loader can reshuffle.
            while True:
                produced = False
                for batch in loader:
                    produced = True
                    yield batch
                if not produced:
                    # Empty loader: stop instead of spinning forever.
                    return

        self.aux_loader = None
        if aux_loader is not None:
            if hasattr(aux_loader, "__next__"):
                # A one-shot iterator cannot be restarted; only a cached
                # replay (itertools.cycle) can repeat it.
                self.aux_loader = cycle(aux_loader)
            else:
                self.aux_loader = _repeat_forever(aux_loader)
        self.alpha_range = torch.linspace(-1.0, 1.0, steps=100)

        logging.info("Evaluate with EMA: %s", use_ema)

    def train(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        ckpt_path,
        dir1,
        dir2,
        writer=None,
        epoch=50,
        start_epoch=0,
        runs_path="./",
    ):
        """Train ``model`` for epochs ``start_epoch .. epoch - 1``.

        Each epoch runs ``_train_one_epoch``, evaluates (EMA weights when
        ``self.use_ema`` is set, raw model otherwise), logs the metrics and
        saves checkpoints for the best clean accuracy, the best adversarial
        accuracy, and the best adversarial accuracy among epochs whose FOSC
        is at most 0.001. A ``KeyboardInterrupt`` stops the loop early; a
        final checkpoint is written in every case.

        Args:
            model: Network to train; moved to ``self.device``.
            train_loader: Loader of ``(img, label)`` training batches.
            val_loader: Loader of ``(img, label)`` validation batches.
            optimizer: Optimizer (an ``AWP`` instance selects two-step updates).
            scheduler: LR scheduler; ``_get_lr`` is queried for logging.
            ckpt_path: Directory the checkpoints are written to.
            dir1: Stored on ``self.dir1`` and forwarded to ``_train_one_epoch``.
            dir2: Stored on ``self.dir2`` and forwarded to ``_train_one_epoch``.
            writer: Optional summary writer providing ``add_scalar``,
                ``flush`` and ``add_hparams``; nothing is written when None.
            epoch: Exclusive upper bound of the epoch range.
            start_epoch: First epoch index (for resumed runs).
            runs_path: Forwarded to ``_train_one_epoch`` and printed at the end.
        """
        global add_noise_batches

        best_eval_clean_acc = 0
        best_eval_adv_acc = 0
        best_safe_eval_adv_acc = 0
        # Pre-bind what the code after the loop reads, so an interrupt before
        # the first evaluation (or an empty epoch range) cannot leave these
        # names undefined.
        eval_adv_acc = 0
        e = start_epoch
        model = model.to(self.device)
        self.dir1 = dir1
        self.dir2 = dir2
        self.model_ema = ModelEmaV2(model, decay=self.model_ema_decay)

        def _checkpoint(filename, at_epoch):
            # Single place for the checkpoint layout used by every save below.
            save_ckpt(
                os.path.join(ckpt_path, filename),
                model,
                self.model_ema.module,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=at_epoch,
            )

        try:
            for e in range(start_epoch, epoch):
                logging.info("Epoch: %s, Train Start", e)
                (
                    epoch_loss,
                    train_clean_acc,
                    train_adv_acc,
                    R,
                    ce_var,
                    kl_var,
                    covar,
                    cesum,
                    klsum,
                    avg_w_norm,
                    wsgcs,
                ) = self._train_one_epoch(
                    train_loader,
                    val_loader,
                    model,
                    optimizer,
                    scheduler,
                    e * len(train_loader),
                    writer,
                    e,
                    runs_path,
                    self.dir1,
                    self.dir2,
                    self.solution,
                )

                lr = scheduler._get_lr(e * len(train_loader))[0]

                eval_model = self.model_ema.module if self.use_ema else model
                eval_clean_acc, eval_adv_acc, FOSC, SGCS = self.eval(
                    eval_model, val_loader
                )

                logging.info("Train Loss: %s", epoch_loss)
                logging.info("Train CE (clean) Loss: %s", cesum)
                logging.info("Train KL (adv*beta) Loss: %s", klsum)
                logging.info("Train Clean Acc: %s", train_clean_acc)
                logging.info("Train Adv Acc: %s", train_adv_acc)
                logging.info("Eval Clean Acc: %s", eval_clean_acc)
                logging.info("Eval Adv Acc: %s", eval_adv_acc)
                logging.info("Train CE (clean) Var: %s", ce_var)
                logging.info("Train KL (adv*beta) Var: %s", kl_var)
                logging.info("Train CEKL Covariance: %s", covar)
                logging.info("Train CEKL ReCoeff: %s", R)
                logging.info("FOSC: %s", FOSC)
                logging.info("SGCS: %s", SGCS)
                logging.info("Weight Grad Norm: %s", avg_w_norm)
                logging.info("WSGCS: %s", wsgcs)

                if eval_clean_acc > best_eval_clean_acc:
                    best_eval_clean_acc = eval_clean_acc
                    _checkpoint("best_clean_score.pt", e)
                if eval_adv_acc > best_eval_adv_acc:
                    best_eval_adv_acc = eval_adv_acc
                    _checkpoint("best_adv_score.pt", e)
                # "Safe" best: only epochs whose FOSC is at most 0.001 count.
                if eval_adv_acc > best_safe_eval_adv_acc and FOSC <= 0.001:
                    best_safe_eval_adv_acc = eval_adv_acc
                    _checkpoint("best_safe_adv_score.pt", e)

                if writer is not None:
                    writer.add_scalar("Score/acc", eval_clean_acc, e)
                    writer.add_scalar("Score/adv_acc", eval_adv_acc, e)
                    writer.add_scalar("Score/train_clean_acc", train_clean_acc, e)
                    writer.add_scalar("Score/train_adv_acc", train_adv_acc, e)
                    writer.add_scalar("Score/gap", train_clean_acc - train_adv_acc, e)

                    writer.add_scalar("Score/R", R, e)
                    writer.add_scalar("Score/covariance", covar, e)
                    writer.add_scalar("Score/CE_variance", ce_var, e)
                    writer.add_scalar("Score/KL_variance", kl_var, e)
                    writer.add_scalar("Loss/CE_loss", cesum, e)
                    writer.add_scalar("Loss/KL_loss", klsum, e)

                    writer.add_scalar("Loss/train", epoch_loss, e)
                    writer.add_scalar("lr", lr, e)
                    writer.add_scalar("SGCS", SGCS, e)
                    writer.add_scalar("FOSC", FOSC, e)
                    writer.add_scalar("W_Grad_Norm", avg_w_norm, e)
                    writer.add_scalar("WSGCS", wsgcs, e)
                    writer.flush()

                # When enabled via hparam["solution"], a FOSC above the
                # threshold makes the next 10 training batches noisy.
                if FOSC > 0.001 and self.solution:
                    add_noise_batches = 10  # Number of batches to add noise
            print(runs_path)

        except KeyboardInterrupt:
            logging.info("Keyboard Interrupt Received, Please Wait....")

        if writer is not None:
            writer.add_hparams(
                self.hparam,
                {
                    "hparam/top1_score": best_eval_clean_acc,
                    "hparam/adv_score": best_eval_adv_acc,
                    "hparam/adv_score_final": eval_adv_acc,
                },
                run_name="record",
            )
        # ``e`` is the last epoch entered, or ``start_epoch`` if none was.
        _checkpoint("final_%s.pt" % (e), e)
 
    def eval(self, model, val_loader):
        """Measure clean and adversarial accuracy of ``model`` on ``val_loader``.

        For every batch the clean prediction is scored, then
        ``self.eval_attacker.attack`` produces the adversarial batch together
        with per-batch FOSC and SGCS values, and the adversarial prediction
        is scored. The model is left in eval mode.

        Args:
            model: Network to evaluate.
            val_loader: Iterable of ``(img, label)`` batches.

        Returns:
            ``(clean_acc, adv_acc, FOSC, SGCS)`` as Python floats, where FOSC
            and SGCS are means over all values returned by the attacker.
        """
        logging.info("Start Validatoin")
        model.eval()

        fosc_values, sgcs_values = [], []
        for img, label in tqdm(val_loader):
            img = img.to(self.device)
            label = label.to(self.device)

            with torch.no_grad():
                clean_logits = model(img)
            self.clean_acc.update(clean_logits, label)

            # The attack runs outside no_grad; only the forward passes used
            # for scoring are wrapped.
            img_adv, batch_fosc, batch_sgcs = self.eval_attacker.attack(model, img, label)
            fosc_values.extend(batch_fosc)
            sgcs_values.extend(batch_sgcs)

            with torch.no_grad():
                adv_logits = model(img_adv)
            self.adv_acc.update(adv_logits, label)

        clean_acc = self.clean_acc.compute().item()
        adv_acc = self.adv_acc.compute().item()

        # Clear the shared metric objects for the next user.
        for metric in (self.clean_acc, self.adv_acc):
            metric.reset()

        mean_fosc = float(sum(fosc_values) / len(fosc_values))
        mean_sgcs = float(sum(sgcs_values) / len(sgcs_values))

        return clean_acc, adv_acc, mean_fosc, mean_sgcs
 
    def _train_one_epoch(
        self,
        train_loader,
        val_loader,
        model,
        optimizer,
        scheduler,
        global_step,
        writer,
        epoch,
        runs_path,
        dir1,
        dir2,
        solution,
    ):
        """Run one pass over ``train_loader`` and collect training diagnostics.

        Depending on ``self.adv_train_mode`` a batch is trained with
        cross-entropy on clean inputs ("Normal"), cross-entropy on adversarial
        inputs ("PGD-AT"), or ``CE(clean) + adv_beta * KL(adv || clean)``
        ("TRADES"). When ``optimizer`` is an ``AWP`` instance the two-step
        ``first_step`` / ``second_step`` update is used instead of ``step``.

        The CE/KL statistics, gradient norms, sign-gradient similarity and
        clean accuracy are gathered only in the TRADES branch without AWP;
        in every other configuration they are returned as 0.

        Args:
            train_loader: Iterable of ``(img, label)`` batches.
            val_loader: Unused here.
            model: Network being trained.
            optimizer: Optimizer, optionally an ``AWP`` instance.
            scheduler: Stepped once per batch with the global step index.
            global_step: Number of batches processed before this epoch.
            writer: Optional summary writer (``add_scalar``); may be None.
            epoch: Epoch index, used in the per-batch scalar tags.
            runs_path: Unused here.
            dir1: Unused here.
            dir2: Unused here.
            solution: When truthy, batches get Gaussian input noise while the
                module-level ``add_noise_batches`` counter is positive.

        Returns:
            An 11-tuple: ``(epoch_loss, clean_train_accuracy,
            epoch_robust_acc, ce_kl_correlation, ce_variance, kl_variance,
            ce_kl_covariance, mean_ce, mean_kl, mean_weight_grad_norm,
            wsgcs)``.
        """
        global add_noise_batches

        cls_loss_fn = torch.nn.CrossEntropyLoss()
        kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
        model.train()
        # Per-batch records, filled only by the TRADES (non-AWP) branch.
        ce_losses = []
        kl_losses = []
        avg_gradient_norms = []
        # NOTE(review): one full sign vector (as many entries as the model
        # has parameters) is kept per batch until the end of the epoch, which
        # can be very large in memory - confirm this is acceptable.
        sign_gradients = []
        correct_clean_preds = 0  # Counter for correct clean predictions
        total_clean_preds = 0  # Counter for total clean predictions

        def log_scalar(tag, value, at_step):
            # The writer is optional (``train`` defaults it to None).
            if writer is not None:
                writer.add_scalar(tag, value, at_step)

        for idx, data in enumerate(tqdm(train_loader)):
            step = global_step + idx
            if step == 0 and self.aux_loader is not None:
                set_bn_momentum(model, momentum=1.0)
            elif step == 1 and self.aux_loader is not None:
                set_bn_momentum(model, momentum=0.01)

            img, label = data
            if add_noise_batches > 0 and solution:
                add_noise_batches -= 1
                img = img + torch.randn_like(img) * 0.1
                # Clamp with torch: ``img`` is a tensor, not a NumPy array.
                img = torch.clamp(img, 0.0, 1.0)

            # When we want to use additional DDPM data.
            if self.aux_loader is not None:
                img_aux, label_aux = next(self.aux_loader)
                img = torch.vstack([img, img_aux])
                label = torch.hstack([label, label_aux])

            img = img.to(self.device, non_blocking=True)
            label = label.to(self.device, non_blocking=True)

            clean_label = label.clone()

            # Generate Adversarial Examples
            if self.adv_train_mode != "Normal":
                img_adv = self.train_attacker.attack(model, img, label)

            # Train the model
            model.train()
            optimizer.zero_grad()

            if self.adv_train_mode == "Normal":
                if isinstance(optimizer, AWP):
                    # First Step
                    logits = model(img)
                    loss = cls_loss_fn(logits, label)
                    loss.backward()
                    optimizer.first_step(zero_grad=True)

                    # Second Step
                    disable_running_stats(model)
                    logits = model(img)
                    cls_loss_fn(logits, label).backward()
                    optimizer.second_step(zero_grad=True)
                    enable_running_stats(model)

                else:
                    logits = model(img)
                    loss = cls_loss_fn(logits, label)
                    loss.backward()
                    optimizer.step()

            elif self.adv_train_mode == "PGD-AT":
                if isinstance(optimizer, AWP):
                    # First Step
                    logits_adv = model(img_adv)
                    loss = cls_loss_fn(logits_adv, label)
                    loss.backward()
                    optimizer.first_step(zero_grad=True)

                    # Second Step
                    disable_running_stats(model)
                    logits = model(img_adv)
                    cls_loss_fn(logits, label).backward()
                    optimizer.second_step(zero_grad=True)
                    enable_running_stats(model)
                else:
                    logits_adv = model(img_adv)
                    loss = cls_loss_fn(logits_adv, label)
                    loss.backward()
                    optimizer.step()

            elif self.adv_train_mode == "TRADES":
                if isinstance(optimizer, AWP):
                    # First Step
                    logits = model(torch.cat([img, img_adv]))
                    logits_natural = logits[: img.shape[0]]
                    logits_adv = logits[img.shape[0] :]

                    loss_natural = cls_loss_fn(logits_natural, label)
                    loss_robust = kl_loss_fn(
                        F.log_softmax(logits_adv, dim=1),
                        F.log_softmax(logits_natural, dim=1),
                    )
                    loss = loss_natural + self.adv_beta * loss_robust
                    loss.backward()
                    optimizer.first_step(zero_grad=True)

                    # Second Step
                    disable_running_stats(model)
                    logits = model(torch.cat([img, img_adv]))
                    logits_natural = logits[: img.shape[0]]
                    logits_adv = logits[img.shape[0] :]

                    loss_natural = cls_loss_fn(logits_natural, label)
                    loss_robust = kl_loss_fn(
                        F.log_softmax(logits_adv, dim=1),
                        F.log_softmax(logits_natural, dim=1),
                    )
                    (loss_natural + self.adv_beta * loss_robust).backward()
                    optimizer.second_step(zero_grad=True)
                    enable_running_stats(model)

                else:
                    logits = model(torch.cat([img, img_adv]))
                    logits_natural = logits[: img.shape[0]]
                    logits_adv = logits[img.shape[0]:]

                    loss_natural = cls_loss_fn(logits_natural, label)
                    loss_robust = kl_loss_fn(
                        F.log_softmax(logits_adv, dim=1),
                        F.log_softmax(logits_natural, dim=1),
                    )
                    loss = loss_natural + self.adv_beta * loss_robust
                    # Detach before storing so the per-batch records do not
                    # keep autograd graphs alive for the whole epoch.
                    ce_losses.append(loss_natural.detach())
                    kl_losses.append((self.adv_beta * loss_robust).detach())

                    # The graph is reused below for the per-term gradients.
                    loss.backward(retain_graph=True)
                    grad_direction = [param.grad.clone() for param in model.parameters() if param.requires_grad]

                    # Global L2 norm of the weight gradient of the full loss.
                    total_norm = 0
                    for param in model.parameters():
                        if param.grad is not None:
                            param_norm = param.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                    total_norm = total_norm ** 0.5

                    predictions = logits_natural.argmax(dim=1)
                    correct_clean_preds += predictions.eq(label).float().sum().item()
                    total_clean_preds += label.size(0)

                    log_scalar(f"epoch_{epoch}_Wgradnorm", total_norm, step)
                    log_scalar(f"epoch_{epoch}_ce", loss_natural, step)
                    log_scalar(f"epoch_{epoch}_kl", loss_robust, step)

                    # Gradients of the natural and robust terms separately;
                    # autograd.grad does not touch ``param.grad``, so the
                    # optimizer step below still uses the full-loss gradient.
                    gradients_natural = torch.autograd.grad(loss_natural, model.parameters(), retain_graph=True, create_graph=False)
                    gradients_adv = torch.autograd.grad(loss_robust, model.parameters(), create_graph=False)

                    # Cosine similarity between the two flattened gradients
                    grad_natural_flat = torch.cat([g.view(-1) for g in gradients_natural if g is not None])
                    grad_adv_flat = torch.cat([g.view(-1) for g in gradients_adv if g is not None])
                    cos_sim = F.cosine_similarity(grad_natural_flat.unsqueeze(0), grad_adv_flat.unsqueeze(0), dim=1)
                    norm_grad_natural = grad_natural_flat.norm()
                    norm_grad_adv = grad_adv_flat.norm()
                    log_scalar(f"epoch_{epoch}_ce_norm", norm_grad_natural, step)
                    log_scalar(f"epoch_{epoch}_kl_norm", norm_grad_adv, step)

                    log_scalar(f"epoch_{epoch}_cosine_similarity_ce_kl", cos_sim.item(), step)

                    avg_gradient_norms.append(total_norm)

                    # Flattened sign of the full-loss gradient for this batch.
                    batch_signs = [
                        torch.sign(param.grad).view(-1)
                        for param in model.parameters()
                        if param.grad is not None
                    ]
                    if batch_signs:
                        sign_gradients.append(torch.cat(batch_signs))

                    optimizer.step()

                    # Diagnostic pass: gradient on the same batch after the
                    # update, compared with the gradient before it. The RNG
                    # state is saved and restored around it.
                    # NOTE(review): this forward runs in train mode and so
                    # presumably updates BatchNorm running statistics a
                    # second time for the same batch - confirm if intended.
                    torch_state, numpy_state, random_state = save_random_state()
                    optimizer.zero_grad()
                    logits = model(torch.cat([img, img_adv]))
                    logits_natural = logits[: img.shape[0]]
                    logits_adv = logits[img.shape[0]:]

                    loss_natural = cls_loss_fn(logits_natural, label)
                    loss_robust = kl_loss_fn(
                        F.log_softmax(logits_adv, dim=1),
                        F.log_softmax(logits_natural, dim=1),
                    )
                    loss2 = loss_natural + self.adv_beta * loss_robust
                    loss2.backward()

                    # Gradient direction after the optimizer step
                    grad_direction_after = [param.grad.clone() for param in model.parameters() if param.requires_grad]

                    # Cosine similarity between the gradients before and after the step
                    grad_before_flat = torch.cat([g.view(-1) for g in grad_direction if g is not None])
                    grad_after_flat = torch.cat([g.view(-1) for g in grad_direction_after if g is not None])
                    cos_sim_grads = F.cosine_similarity(grad_before_flat.unsqueeze(0), grad_after_flat.unsqueeze(0), dim=1).item()

                    log_scalar(f"epoch_{epoch}_grad_cosine_similarity", cos_sim_grads, step)
                    # Discard the diagnostic gradients.
                    optimizer.zero_grad()
                    restore_random_state(torch_state, numpy_state, random_state)

            scheduler.step(step)

            if self.adv_train_mode != "Normal":
                self.adv_acc.update(logits_adv, clean_label)
            self.model_ema.update(model)
            self.epoch_loss.update(loss)

        if avg_gradient_norms:
            epoch_avg_grad_norm = sum(avg_gradient_norms) / len(avg_gradient_norms)
        else:
            epoch_avg_grad_norm = 0
        epoch_loss = self.epoch_loss.compute().item()
        epoch_robust_acc = (
            0 if self.adv_train_mode == "Normal" else self.adv_acc.compute().item()
        )

        ce_values = [tensor.item() for tensor in ce_losses]
        kl_values = [tensor.item() for tensor in kl_losses]
        if ce_values:
            correlation_coefficient = np.corrcoef(ce_values, kl_values)[0, 1]
            variance_ce = np.var(ce_values)
            variance_kl = np.var(kl_values)
            covariance = np.cov(ce_values, kl_values)[0, 1]
            mean_ce = np.average(ce_values)
            mean_kl = np.average(kl_values)
        else:
            # Nothing was recorded (any mode other than TRADES without AWP):
            # report zeros, like the other fallbacks in this method.
            correlation_coefficient = 0.0
            variance_ce = 0.0
            variance_kl = 0.0
            covariance = 0.0
            mean_ce = 0.0
            mean_kl = 0.0

        # WSGCS: mean cosine similarity over all pairs of per-batch
        # sign-gradient vectors collected this epoch.
        wsgcs = 0.0
        if sign_gradients:
            pair_sum = 0
            n_combinations = 0
            for i, first in enumerate(sign_gradients):
                for second in sign_gradients[i + 1:]:
                    pair_sum += F.cosine_similarity(first.unsqueeze(0), second.unsqueeze(0), dim=1)
                    n_combinations += 1
            if n_combinations > 0:
                wsgcs = (pair_sum / n_combinations).item()
        else:
            logging.info('No gradients collected for WSGCS calculation.')

        # Clean accuracy is only counted in the TRADES (non-AWP) branch.
        if total_clean_preds > 0:
            clean_train_accuracy = correct_clean_preds / total_clean_preds
        else:
            clean_train_accuracy = 0.0

        self.epoch_loss.reset()
        self.adv_acc.reset()
        return (
            epoch_loss,
            clean_train_accuracy,
            epoch_robust_acc,
            correlation_coefficient,
            variance_ce,
            variance_kl,
            covariance,
            mean_ce,
            mean_kl,
            epoch_avg_grad_norm,
            wsgcs,
        )